{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "fccdc384-f963-4407-8d64-22d20aab1d56",
   "metadata": {
    "id": "XIyP_0r6zuVc"
   },
   "source": [
    "<!-- Banner Image -->\n",
    "<img src=\"https://uohmivykqgnnbiouffke.supabase.co/storage/v1/object/public/landingpage/ocr2.png?t=2023-11-09T00%3A26%3A25.198Z\" width=\"100%\">\n",
    "\n",
    "<!-- Links -->\n",
    "<center>\n",
    "  <a href=\"https://console.brev.dev\" style=\"color: #06b6d4;\">Console</a> •\n",
    "  <a href=\"https://brev.dev\" style=\"color: #06b6d4;\">Docs</a> •\n",
    "  <a href=\"/\" style=\"color: #06b6d4;\">Templates</a> •\n",
    "  <a href=\"https://discord.gg/NVDyv7TUgJ\" style=\"color: #06b6d4;\">Discord</a>\n",
    "</center>\n",
    "\n",
    "# Predicting Protein-Protein Interactions Using a Protein Language Model and Linear Sum Assignment\n",
    "\n",
    "Welcome! **This is the notebook version of [this post](https://huggingface.co/blog/AmelieSchreiber/protein-binding-partners-with-esm2) by [Amelie Schreiber](https://huggingface.co/blog/AmelieSchreiber/protein-binding-partners-with-esm2).**\n",
    "\n",
    "#### In this notebook and tutorial, we'll use ESM-2, a **protein language model**, to score pairs of proteins using **masked language modeling loss** in order to **predict pairs of proteins that have a high likelihood of binding to one another**.\n",
    "\n",
    "First, let's get some background on two major topics we'll cover in this notebook: protein language models, and masked language modeling loss.\n",
    "\n",
    "## A. Protein Language Models:\n",
    "\n",
    "Protein language models in biology are computational models that apply the principles of natural language processing (NLP) to the \"language\" of proteins, which is the sequence of amino acids that form a protein. These models treat the sequences of amino acids in proteins similarly to how conventional NLP models treat words in a sentence. The idea is to capture and predict the complex patterns of protein structure, function, and interactions based on their amino acid sequences.\n",
    "\n",
    "### How They Work:\n",
    "1. **Sequence Representation**: Just as words are the basic units of language, amino acids are the basic units of proteins. Protein language models represent proteins as sequences of amino acids, using the one-letter codes (e.g., A for Alanine, R for Arginine) to represent each amino acid.\n",
    "\n",
    "2. **Learning Patterns**: These models are trained on large databases of known protein sequences, learning patterns and relationships between the amino acids in a sequence. They use architectures similar to those used in NLP, such as transformers and recurrent neural networks, to capture the contextual relationships between amino acids in a sequence.\n",
    "\n",
    "3. **Embeddings**: Protein language models produce embeddings for amino acid sequences: high-dimensional vector representations that capture the contextual relationships and properties of each residue or of the whole sequence. These embeddings can then be used to predict the structure, function, or interactions of the protein.\n",
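    "\n",
    "To make the sequence representation concrete, here is a minimal sketch of how a protein string could be turned into integer token IDs. The vocabulary is a hypothetical toy (a few special tokens plus the 20 standard amino acids), not ESM-2's actual vocabulary; later in this notebook, the real mapping is handled by the ESM-2 tokenizer.\n",
    "\n",
    "```python\n",
    "# Hypothetical toy vocabulary: special tokens first, then the 20 standard amino acids\n",
    "SPECIAL_TOKENS = ['<cls>', '<eos>', '<mask>']\n",
    "AMINO_ACIDS = 'ACDEFGHIKLMNPQRSTVWY'\n",
    "VOCAB = {tok: i for i, tok in enumerate(SPECIAL_TOKENS + list(AMINO_ACIDS))}\n",
    "\n",
    "def encode(sequence):\n",
    "    # One token per residue, wrapped in a start token and an end token\n",
    "    return [VOCAB['<cls>']] + [VOCAB[aa] for aa in sequence] + [VOCAB['<eos>']]\n",
    "\n",
    "print(encode('MALW'))  # [0, 13, 3, 12, 21, 1]\n",
    "```\n",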
    "\n",
    "### Uses:\n",
    "- **Protein Structure Prediction**: One of the primary applications is predicting the three-dimensional structure of proteins based on their amino acid sequences. Understanding the structure of a protein is crucial for elucidating its function and for drug discovery efforts.\n",
    "\n",
    "- **Function Prediction**: Combined with annotated protein databases, the representations these models learn can be used to predict the function of proteins. This is particularly useful for newly discovered proteins whose functions are unknown.\n",
    "\n",
    "- **Protein Engineering**: By understanding how changes in amino acid sequences affect protein structure and function, these models can be used to design proteins with desired properties, such as increased stability or novel functionalities.\n",
    "\n",
    "- **Drug Discovery**: Protein language models can help in identifying potential drug targets and in designing molecules that interact with proteins in specific ways to modulate their function.\n",
    "\n",
    "- **Understanding Disease Mechanisms**: They can be used to study how genetic mutations affecting protein sequences lead to changes in protein function and contribute to diseases. This can help in identifying potential therapeutic targets.\n",
    "\n",
    "In summary, protein language models are a powerful tool in computational biology and bioinformatics, offering insights into protein structure, function, and interactions that are fundamental for biological research and pharmaceutical development.\n",
    "\n",
    "\n",
    "## B. Masked Language Modeling Loss:\n",
    "\n",
    "Masked Language Modeling (MLM) loss is a concept derived from training techniques used in natural language processing (NLP) models, particularly in the context of transformer-based architectures like BERT (Bidirectional Encoder Representations from Transformers). Although originally developed for text data, the concept can also be applied to other sequences, such as proteins in computational biology, as mentioned earlier. Here, I'll explain the concept primarily from the NLP perspective, but the underlying principles are broadly applicable.\n",
    "\n",
    "\n",
    "### What is Masked Language Modeling?\n",
    "Masked Language Modeling is a training strategy where a certain percentage of the input tokens (e.g., words in a sentence) are randomly masked, or hidden from the model, during training. The model's task is to predict these masked tokens based only on the context provided by the unmasked tokens. This approach encourages the model to learn a deep, contextual understanding of the language.\n",
    "\n",
    "### How MLM Loss is Calculated:\n",
    "1. **Token Masking**: In the input sequence, a set percentage of tokens (15% in BERT) is selected for prediction. Most of the selected tokens are replaced with a special [MASK] token; BERT, for example, replaces 80% of them with [MASK], 10% with a random token, and leaves 10% unchanged to improve robustness.\n",
    "\n",
    "2. **Model Prediction**: The model processes the altered sequence and tries to predict the original token at each masked position. It generates a probability distribution over the entire vocabulary for each masked token, indicating how likely each token is to be the correct replacement.\n",
    "\n",
    "3. **Loss Calculation**: The MLM loss for a given masked token is calculated by comparing the model's predicted probability distribution against the true token. This is typically done using a loss function suitable for classification tasks, such as Cross-Entropy Loss. The MLM loss for the entire sequence is the average loss across all masked tokens.\n",
    "\n",
    "4. **Optimization**: The model parameters are updated to minimize the MLM loss. Through this process, the model learns to use contextual information to predict the masked tokens accurately.\n",
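    "\n",
    "The steps above can be sketched in NumPy. This is a simplified illustration, not the BERT or ESM-2 training code: `bert_style_mask` selects each position with probability `p` and applies BERT's 80/10/10 corruption rule, and `mlm_loss` computes the cross-entropy averaged over the selected positions only. Positions that should not contribute to the loss get the label -100, mirroring the default `ignore_index=-100` of PyTorch's `CrossEntropyLoss`; the token IDs and logits are toy values.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "IGNORE = -100  # label for positions that do not contribute to the loss\n",
    "\n",
    "def bert_style_mask(token_ids, p, mask_id, vocab_size, rng):\n",
    "    # Select roughly a fraction p of the positions as prediction targets\n",
    "    token_ids = np.asarray(token_ids)\n",
    "    selected = rng.random(token_ids.shape) < p\n",
    "    labels = np.where(selected, token_ids, IGNORE)\n",
    "    corrupted = token_ids.copy()\n",
    "    # Of the selected positions: 80% -> mask token, 10% -> random token, 10% -> unchanged\n",
    "    r = rng.random(token_ids.shape)\n",
    "    corrupted[selected & (r < 0.8)] = mask_id\n",
    "    random_pos = selected & (r >= 0.8) & (r < 0.9)\n",
    "    corrupted[random_pos] = rng.integers(0, vocab_size, size=int(random_pos.sum()))\n",
    "    return corrupted, labels\n",
    "\n",
    "def mlm_loss(logits, labels):\n",
    "    # Numerically stable log-softmax over the vocabulary axis\n",
    "    shifted = logits - logits.max(axis=-1, keepdims=True)\n",
    "    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))\n",
    "    targets = labels != IGNORE\n",
    "    # Negative log-likelihood of the true token at each target position, then the mean\n",
    "    nll = -log_probs[targets, labels[targets]]\n",
    "    return nll.mean()\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "ids = np.array([5, 7, 9, 11, 13, 6, 8, 10])\n",
    "corrupted, labels = bert_style_mask(ids, p=0.15, mask_id=2, vocab_size=23, rng=rng)\n",
    "\n",
    "# With uniform logits each token has probability 1/V, so the loss equals log(V)\n",
    "uniform_logits = np.zeros((len(ids), 23))\n",
    "demo_labels = np.full(len(ids), IGNORE)\n",
    "demo_labels[[1, 4]] = ids[[1, 4]]\n",
    "print(mlm_loss(uniform_logits, demo_labels))  # log(23), about 3.1355\n",
    "```\n",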
    "\n",
    "### Purpose and Benefits:\n",
    "- **Contextual Understanding**: MLM forces the model to learn context-dependent representations of tokens, as it must use the surrounding tokens to predict the masked ones. This leads to a rich understanding of language (or sequences in other domains).\n",
    "\n",
    "- **Bidirectional Context**: Unlike traditional language models that predict each token based on the preceding tokens (left-to-right or right-to-left), MLM allows the model to use both left and right context, resulting in more robust embeddings.\n",
    "\n",
    "- **Pretraining for Downstream Tasks**: MLM is often used for pretraining language models on large text corpora. The pretrained models can then be fine-tuned on smaller, task-specific datasets, significantly improving performance on a wide range of NLP tasks.\n",
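    "\n",
    "The difference between left-to-right and bidirectional context shows up in the attention mask. In this minimal NumPy sketch (the helper names are illustrative), `mask[i, j]` is `True` when position `i` may attend to position `j`: a causal model only sees earlier positions, whereas an MLM-trained encoder such as BERT or ESM-2 sees the whole sequence.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def causal_mask(n):\n",
    "    # Lower-triangular: position i sees positions 0..i (left context only)\n",
    "    return np.tril(np.ones((n, n), dtype=bool))\n",
    "\n",
    "def bidirectional_mask(n):\n",
    "    # Every position sees every position (left and right context)\n",
    "    return np.ones((n, n), dtype=bool)\n",
    "\n",
    "# Positions visible from the middle token (index 2) of a 5-token sequence\n",
    "print(causal_mask(5)[2].sum(), bidirectional_mask(5)[2].sum())  # 3 5\n",
    "```\n",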
    "\n",
    "In summary, Masked Language Modeling loss is a crucial component of training strategies that aim to develop models capable of understanding the intricate patterns and relationships within sequences, whether they be in natural language texts or biological sequences like proteins.\n",
    "\n",
    "\n",
    "## C. Predicting Protein-Protein Interactions with MLM Loss + Protein Language Models\n",
    "In this notebook, we use the protein language model ESM-2.\n",
    "### 1. Understanding ESM-2 and MLM Loss:\n",
    "- **ESM-2**: ESM-2 is a transformer-based protein language model from Meta AI, pretrained with the MLM objective on a large database of known protein sequences (UniRef). It is designed to capture the complex patterns of amino acid sequences that define a protein's structure and function.\n",
    "- **MLM Loss**: In the context of protein modeling, MLM loss measures how well the model predicts the identity of masked (hidden) amino acids in a sequence based on the surrounding context. The loss is lower when the model assigns higher probability to the true amino acids at the masked positions.\n",
    "\n",
    "### 2. Predicting Protein-Protein Interactions:\n",
    "- **Sequence Pairing**: To predict protein-protein interactions (PPIs), pairs of protein sequences are analyzed together. In this notebook, the two sequences are simply concatenated into a single input sequence.\n",
    "\n",
    "- **Masking Strategy**: Amino acids in one or both proteins are masked, and the model predicts these masked residues from the context provided by both sequences. This tests whether the presence of one protein's sequence helps the model predict the residues of the other, which serves as a proxy for how likely the two are to interact.\n",
    "\n",
    "- **MLM Loss Comparison**: By comparing the MLM loss across different protein pairings, we can rank potential interactions. A lower MLM loss when two proteins are analyzed together versus separately suggests that the model finds a coherent or complementary context between them, indicating a potential interaction.\n",
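    "\n",
    "The masking strategy above can be sketched as a small helper that chooses which positions to mask in a concatenated pair. It assumes the token layout start token, protein A, protein B, end token (the layout the ESM-2 tokenizer produces for a plain concatenated string, as used later in this notebook) and masks a fraction `p` of the residues in protein A, protein B, or both, never the special tokens. The function name and its `which` argument are hypothetical.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def pair_mask_positions(len_a, len_b, which, p, rng):\n",
    "    # Assumed layout: position 0 = start token, then protein A, then protein B, then the end token\n",
    "    a_positions = np.arange(1, 1 + len_a)\n",
    "    b_positions = np.arange(1 + len_a, 1 + len_a + len_b)\n",
    "    candidates = {'A': a_positions, 'B': b_positions,\n",
    "                  'both': np.concatenate([a_positions, b_positions])}[which]\n",
    "    # Mask a fraction p of the candidate residues (at least one), without replacement\n",
    "    n_mask = max(1, int(p * len(candidates)))\n",
    "    return np.sort(rng.choice(candidates, size=n_mask, replace=False))\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "positions = pair_mask_positions(len_a=20, len_b=30, which='B', p=0.15, rng=rng)\n",
    "print(positions)  # 4 sorted positions, all within protein B (positions 21 to 50)\n",
    "```\n",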
    "\n",
    "### 3. Interpretation and Validation:\n",
    "- **Interpretation**: A significant drop in MLM loss for a specific protein pair suggests that the model recognizes a meaningful relationship between their sequences, which could reflect a real biological interaction.\n",
    "\n",
    "- **Validation**: Predicted interactions can be validated through experimental techniques such as co-immunoprecipitation or yeast two-hybrid assays, providing empirical evidence for the model's predictions.\n",
    "\n",
    "\n",
    "\n",
    "## This Notebook\n",
    "\n",
    "In the paper [Pairing interacting protein sequences using masked language modeling](https://arxiv.org/abs/2308.07136), the authors propose a method that uses either of two protein language models, [MSA Transformer](https://huggingface.co/models?other=MSA) or [ESM-2](https://huggingface.co/docs/transformers/model_doc/esm), to predict how likely it is that two proteins bind to one another. In this notebook, we will focus on ESM-2. The method is very simple:\n",
    "1. We take a list of proteins we would like to test for interactions, then concatenate them in pairs.\n",
    "2. We use the masked language modeling capabilities of ESM-2 and randomly mask residues, then compute the MLM loss.\n",
    "3. We average over several iterations of this for each pair of proteins, obtaining a score that indicates how likely two proteins are to bind to one another.\n",
    "4. We then compute a matrix of such scores.\n",
    "5. Using this matrix, we solve the associated linear assignment problem to find the optimal binding partners.\n",
    "\n",
    "#### Predicting protein-protein interactions is a critical task in molecular biology. Here, we'll use the ESM-2 model from Meta AI to compute the masked language modeling (MLM) loss for protein pairs, aiming to find the pairs with the lowest loss. The rationale is that proteins that actually interact should, on average, produce a lower MLM loss than those that don't.\n",
    "\n",
    "\n",
    "##### Help us make this tutorial better! Please provide feedback on the [Discord channel](https://discord.gg/pnCpkwU3G5) or on [X](https://x.com/harperscarroll)."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6478f6cf-cb8c-4a2b-bad4-ebb655e21173",
   "metadata": {
    "id": "hWI-uRLEyRgb"
   },
   "source": [
    "## Let's begin!\n",
    "\n",
    "I used a GPU and dev environment from [brev.dev](https://brev.dev). Click the badge below to get your preconfigured instance:\n",
    "\n",
    "[![](https://uohmivykqgnnbiouffke.supabase.co/storage/v1/object/public/landingpage/brevdeploynavy.svg)](https://console.brev.dev/environment/new?instance=T4:g4dn.xlarge&diskStorage=120&name=protein-demo&python=3.10&cuda=12.1.1)\n",
    "\n",
    "Once you've checked out your machine and landed in your instance page, select the specs you'd like (I used **Python 3.10 and CUDA 12.1.1**; these should be preconfigured for you if you use the badge above) and click the \"Build\" button to build your verb container. Give this a few minutes.\n",
    "\n",
    "A few minutes after your instance has started Running, click the 'Notebook' button on the top right of your screen once it illuminates (you may need to refresh the page). You will be taken to a JupyterLab environment, where you can upload this notebook.\n",
    "\n",
    "Note: You can connect your cloud credits (AWS or GCP) by clicking \"Org: \" on the top right, and in the panel that slides over, click \"Connect AWS\" or \"Connect GCP\" under \"Connect your cloud\" and follow the instructions linked to attach your credentials.\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d9fdb383-710a-4e2f-a7e1-3aed8bd2c588",
   "metadata": {},
   "source": [
    "## 1. Install Libraries"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c801cbcb-0e06-4985-b80f-8e76a25f66eb",
   "metadata": {},
   "source": [
    "Let's install the libraries we'll be using."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "56147c6c-1e4c-4a5e-8ca0-fc1be239cf14",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m23.0.1\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m24.0\u001b[0m\n",
      "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpip install --upgrade pip\u001b[0m\n",
      "\n",
      "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m23.0.1\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m24.0\u001b[0m\n",
      "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpip install --upgrade pip\u001b[0m\n"
     ]
    }
   ],
   "source": [
    "!pip install --upgrade numpy scipy transformers plotly jupyter ipywidgets jupyterlab_widgets -q\n",
    "!pip install torch torchvision torchaudio -q"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6534f931-efdd-47be-a48b-530070963f63",
   "metadata": {},
   "source": [
    "Later, we'll be using Jupyter widgets, so we need to make sure a recent version of Node.js is installed and that Jupyter widgets are enabled in JupyterLab."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "5b28ada0-c551-4828-9bb3-0be1bc918464",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[38;5;79m2024-02-08 03:05:24 - Installing pre-requisites\u001b[0m\n",
      "Hit:1 https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64  InRelease\n",
      "Hit:2 https://deb.nodesource.com/node_18.x nodistro InRelease                  \n",
      "Hit:3 http://archive.ubuntu.com/ubuntu jammy InRelease                         \n",
      "Hit:4 http://archive.ubuntu.com/ubuntu jammy-updates InRelease               \n",
      "Hit:5 http://archive.ubuntu.com/ubuntu jammy-backports InRelease\n",
      "Hit:6 http://security.ubuntu.com/ubuntu jammy-security InRelease\n",
      "Reading package lists... Done\n",
      "Reading package lists... Done\n",
      "Building dependency tree... Done\n",
      "Reading state information... Done\n",
      "ca-certificates is already the newest version (20230311ubuntu0.22.04.1).\n",
      "curl is already the newest version (7.81.0-1ubuntu1.15).\n",
      "gnupg is already the newest version (2.2.27-3ubuntu2.1).\n",
      "apt-transport-https is already the newest version (2.4.11).\n",
      "0 upgraded, 0 newly installed, 0 to remove and 42 not upgraded.\n",
      "Hit:1 https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64  InRelease\n",
      "Hit:2 http://archive.ubuntu.com/ubuntu jammy InRelease                         \n",
      "Hit:3 https://deb.nodesource.com/node_18.x nodistro InRelease                  \n",
      "Hit:4 http://archive.ubuntu.com/ubuntu jammy-updates InRelease                 \n",
      "Hit:5 http://archive.ubuntu.com/ubuntu jammy-backports InRelease             \n",
      "Hit:6 http://security.ubuntu.com/ubuntu jammy-security InRelease\n",
      "Reading package lists... Done\n",
      "\u001b[1;32m2024-02-08 03:05:30 - Repository configured successfully. To install Node.js, run: apt-get install nodejs -y\u001b[0m\n",
      "Hit:1 https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64  InRelease\n",
      "Hit:2 https://deb.nodesource.com/node_18.x nodistro InRelease       \u001b[0m\n",
      "Hit:3 http://security.ubuntu.com/ubuntu jammy-security InRelease    \n",
      "Hit:4 http://archive.ubuntu.com/ubuntu jammy InRelease\n",
      "Hit:5 http://archive.ubuntu.com/ubuntu jammy-updates InRelease\n",
      "Hit:6 http://archive.ubuntu.com/ubuntu jammy-backports InRelease\n",
      "Reading package lists... Done\u001b[33m\u001b[33m\u001b[33m\n",
      "Building dependency tree... Done\n",
      "Reading state information... Done\n",
      "42 packages can be upgraded. Run 'apt list --upgradable' to see them.\n",
      "Reading package lists... Done\n",
      "Building dependency tree... Done\n",
      "Reading state information... Done\n",
      "nodejs is already the newest version (18.19.0-1nodesource1).\n",
      "0 upgraded, 0 newly installed, 0 to remove and 42 not upgraded.\n"
     ]
    }
   ],
   "source": [
    "!curl -fsSL https://deb.nodesource.com/setup_18.x | sudo -E bash -\n",
    "!sudo apt update && sudo apt install nodejs -y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "89ddc744-926a-4a6f-ba5e-3c90915f88ad",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[LabCleanApp] Cleaning /home/ubuntu/.pyenv/versions/3.10.13/share/jupyter/lab...\n",
      "[LabCleanApp] staging not present, skipping...\n",
      "[LabCleanApp] Success!\n",
      "\u001b[33m(Deprecated) Installing extensions with the jupyter labextension install command is now deprecated and will be removed in a future major version of JupyterLab.\n",
      "\n",
      "Users should manage prebuilt extensions with package managers like pip and conda, and extension authors are encouraged to distribute their extensions as prebuilt packages \u001b[0m\n",
      "Building jupyterlab assets (production, minimized)\n"
     ]
    }
   ],
   "source": [
    "!jupyter lab clean\n",
    "!jupyter labextension install @jupyter-widgets/jupyterlab-manager"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b1a9f54a-1507-4ab5-bc0e-4bea811475ea",
   "metadata": {},
   "source": [
    "Make sure this outputs version >= 18.0.0."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "490fc1aa-6f85-4f88-a7c3-e233f30ac543",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "v18.19.0\n"
     ]
    }
   ],
   "source": [
    "!node -v"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a03d695d-41d9-4a47-8b1b-72caed509f08",
   "metadata": {},
   "source": [
    "Now, restart JupyterLab.\n",
    "Exit this window. In a terminal on your laptop, run `brev notebook protein-demo`, or `brev notebook <INSTANCE_NAME>` if you gave your instance a name other than `protein-demo`.\n",
    "Then, click the link that appears in the terminal window."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bc383958-ec2c-4a25-8ee6-a8be38ac84a1",
   "metadata": {},
   "source": [
    "## 1.1 Import Libraries"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "36566638-4edd-4f39-b81a-3fb47dd6d07c",
   "metadata": {},
   "source": [
    "Now let's import the necessary libraries.\n",
    "\n",
    "- **numpy**: Used for numerical operations, especially for handling matrices.\n",
    "- **linear_sum_assignment**: This is an optimization algorithm from the SciPy library that solves the linear sum assignment problem. We'll use this to find the optimal pairing of proteins based on the MLM loss.\n",
    "- **transformers**: Hugging Face's library that provides pretrained models for NLP and other domains. We're using it to load the ESM-2 model and its tokenizer.\n",
    "- **torch**: The PyTorch library, which serves as the backend for the Transformers models used here."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "154e6b22-8d4c-484d-95a8-b5c51f7fe959",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from scipy.optimize import linear_sum_assignment\n",
    "from transformers import AutoTokenizer, EsmForMaskedLM\n",
    "import torch"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b56169c1-1a41-4aba-8320-d88c02fd5fdc",
   "metadata": {},
   "source": [
    "## 2. Initialize the Model & Tokenizer\n",
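    "Here, we load Meta's (formerly Facebook's) ESM-2 model using the Hugging Face Transformers library. We use the `facebook/esm2_t6_8M_UR50D` checkpoint, the smallest ESM-2 model (6 layers, 8 million parameters).\n",
    "- **tokenizer**: Converts protein sequences into token IDs the model can process.\n",
    "- **model**: The ESM-2 model with a masked language modeling head (`EsmForMaskedLM`), which predicts the amino acid at each masked position."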
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "b5729711-2089-47b8-9a1e-0213efb1f1b6",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a94148d380374fafb27526d08079b040",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "tokenizer_config.json:   0%|          | 0.00/95.0 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "ad3422f5cf3445b29b413f1f2c28d290",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "vocab.txt:   0%|          | 0.00/93.0 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e36988c53a6f41f784194ed4e0e7e7d7",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "special_tokens_map.json:   0%|          | 0.00/125 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "bdc77a968de941068cb0e1abe73bc941",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "config.json:   0%|          | 0.00/775 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "4e0112768666405c8557e8ccb537d55d",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "model.safetensors:   0%|          | 0.00/31.4M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "tokenizer = AutoTokenizer.from_pretrained(\"facebook/esm2_t6_8M_UR50D\")\n",
    "model = EsmForMaskedLM.from_pretrained(\"facebook/esm2_t6_8M_UR50D\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "344ec9bc-0733-4bef-b282-10d32e2445d5",
   "metadata": {},
   "source": [
    "## 3. Set up the GPU\n",
    "Using a GPU can greatly speed up computations, and Brev is great for easy, on-demand access to GPUs! The following code checks if a GPU is available (it is, if you're using Brev!) and sets the model to run on it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "b8246b4d-a7fa-4e77-a29b-d8f4a567fc83",
   "metadata": {},
   "outputs": [],
   "source": [
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "model = model.to(device)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e9098e3c-2f77-4a97-9b99-c71f288615c3",
   "metadata": {},
   "source": [
    "## 4. Define the Protein Sequences\n",
    "\n",
    "For this tutorial, we've chosen six protein sequences for our analysis:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "f5a68fb3-0ffb-4d9b-8ee4-1ebc954e4c38",
   "metadata": {},
   "outputs": [],
   "source": [
    "list_of_protein_sequences = [\n",
    "    \"MEESQSELNIDPPLSQETFSELWNLLPENNVLSSELCPAVDELLLPESVVNWLDEDSDDAPRMPATSAPTAPGPAPSWPLSSSVPSPKTYPGTYGFRLGFLHSGTAKSVTWTYSPLLNKLFCQLAKTCPVQLWVSSPPPPNTCVRAMAIYKKSEFVTEVVRRCPHHERCSDSSDGLAPPQHLIRVEGNLRAKYLDDRNTFRHSVVVPYEPPEVGSDYTTIHYNYMCNSSCMGGMNRRPILTIITLEDSSGNVLGRNSFEVRVCACPGRDRRTEEENFHKKGEPCPEPPPGSTKRALPPSTSSSPPQKKKPLDGEYFTLQIRGRERYEMFRNLNEALELKDAQSGKEPGGSRAHSSHLKAKKGQSTSRHKKLMFKREGLDSD\",\n",
    "    \"MCNTNMSVPTDGAVTTSQIPASEQETLVRPKPLLLKLLKSVGAQKDTYTMKEVLFYLGQYIMTKRLYDEKQQHIVYCSNDLLGDLFGVPSFSVKEHRKIYTMIYRNLVVVNQQESSDSGTSVSENRCHLEGGSDQKDLVQELQEEKPSSSHLVSRPSTSSRRRAISETEENSDELSGERQRKRHKSDSISLSFDESLALCVIREICCERSSSSESTGTPSNPDLDAGVSEHSGDWLDQDSVSDQFSVEFEVESLDSEDYSLSEEGQELSDEDDEVYQVTVYQAGESDTDSFEEDPEISLADYWKCTSCNEMNPPLPSHCNRCWALRENWLPEDKGKDKGEISEKAKLENSTQAEEGFDVPDCKKTIVNDSRESCVEENDDKITQASQSQESEDYSQPSTSSSIIYSSQEDVKEFEREETQDKEESVESSLPLNAIEPCVICQGRPKNGCIVHGKTGHLMACFTCAKKLKKRNKPCPVCRQPIQMIVLTYFP\",\n",
    "    \"MNRGVPFRHLLLVLQLALLPAATQGKKVVLGKKGDTVELTCTASQKKSIQFHWKNSNQIKILGNQGSFLTKGPSKLNDRADSRRSLWDQGNFPLIIKNLKIEDSDTYICEVEDQKEEVQLLVFGLTANSDTHLLQGQSLTLTLESPPGSSPSVQCRSPRGKNIQGGKTLSVSQLELQDSGTWTCTVLQNQKKVEFKIDIVVLAFQKASSIVYKKEGEQVEFSFPLAFTVEKLTGSGELWWQAERASSSKSWITFDLKNKEVSVKRVTQDPKLQMGKKLPLHLTLPQALPQYAGSGNLTLALEAKTGKLHQEVNLVVMRATQLQKNLTCEVWGPTSPKLMLSLKLENKEAKVSKREKAVWVLNPEAGMWQCLLSDSGQVLLESNIKVLPTWSTPVQPMALIVLGGVAGLLLFIGLGIFFCVRCRHRRRQAERMSQIKRLLSEKKTCQCPHRFQKTCSPI\",\n",
    "    \"MRVKEKYQHLWRWGWKWGTMLLGILMICSATEKLWVTVYYGVPVWKEATTTLFCASDAKAYDTEVHNVWATHACVPTDPNPQEVVLVNVTENFNMWKNDMVEQMHEDIISLWDQSLKPCVKLTPLCVSLKCTDLGNATNTNSSNTNSSSGEMMMEKGEIKNCSFNISTSIRGKVQKEYAFFYKLDIIPIDNDTTSYTLTSCNTSVITQACPKVSFEPIPIHYCAPAGFAILKCNNKTFNGTGPCTNVSTVQCTHGIRPVVSTQLLLNGSLAEEEVVIRSANFTDNAKTIIVQLNQSVEINCTRPNNNTRKSIRIQRGPGRAFVTIGKIGNMRQAHCNISRAKWNATLKQIASKLREQFGNNKTIIFKQSSGGDPEIVTHSFNCGGEFFYCNSTQLFNSTWFNSTWSTEGSNNTEGSDTITLPCRIKQFINMWQEVGKAMYAPPISGQIRCSSNITGLLLTRDGGNNNNGSEIFRPGGGDMRDNWRSELYKYKVVKIEPLGVAPTKAKRRVVQREKRAVGIGALFLGFLGAAGSTMGARSMTLTVQARQLLSGIVQQQNNLLRAIEAQQHLLQLTVWGIKQLQARILAVERYLKDQQLLGIWGCSGKLICTTAVPWNASWSNKSLEQIWNNMTWMEWDREINNYTSLIHSLIEESQNQQEKNEQELLELDKWASLWNWFNITNWLWYIKIFIMIVGGLVGLRIVFAVLSIVNRVRQGYSPLSFQTHLPTPRGPDRPEGIEEEGGERDRDRSIRLVNGSLALIWDDLRSLCLFSYHRLRDLLLIVTRIVELLGRRGWEALKYWWNLLQYWSQELKNSAVSLLNATAIAVAEGTDRVIEVVQGACRAIRHIPRRIRQGLERILL\",\n",
    "    \"MALWMRLLPLLALLALWGPDPAAAFVNQHLCGSHLVEALYLVCGERGFFYTPKTRREAEDLQVGQVELGGGPGAGSLQPLALEGSLQKRGIVEQCCTSICSLYQLENYCN\", \n",
    "    \"MATGGRRGAAAAPLLVAVAALLLGAAGHLYPGEVCPGMDIRNNLTRLHELENCSVIEGHLQILLMFKTRPEDFRDLSFPKLIMITDYLLLFRVYGLESLKDLFPNLTVIRGSRLFFNYALVIFEMVHLKELGLYNLMNITRGSVRIEKNNELCYLATIDWSRILDSVEDNYIVLNKDDNEECGDICPGTAKGKTNCPATVINGQFVERCWTHSHCQKVCPTICKSHGCTAEGLCCHSECLGNCSQPDDPTKCVACRNFYLDGRCVETCPPPYYHFQDWRCVNFSFCQDLHHKCKNSRRQGCHQYVIHNNKCIPECPSGYTMNSSNLLCTPCLGPCPKVCHLLEGEKTIDSVTSAQELRGCTVINGSLIINIRGGNNLAAELEANLGLIEEISGYLKIRRSYALVSLSFFRKLRLIRGETLEIGNYSFYALDNQNLRQLWDWSKHNLTITQGKLFFHYNPKLCLSEIHKMEEVSGTKGRQERNDIALKTNGDQASCENELLKFSYIRTSFDKILLRWEPYWPPDFRDLLGFMLFYKEAPYQNVTEFDGQDACGSNSWTVVDIDPPLRSNDPKSQNHPGWLMRGLKPWTQYAIFVKTLVTFSDERRTYGAKSDIIYVQTDATNPSVPLDPISVSNSSSQIILKWKPPSDPNGNITHYLVFWERQAEDSELFELDYCLKGLKLPSRTWSPPFESEDSQKHNQSEYEDSAGECCSCPKTDSQILKELEESSFRKTFEDYLHNVVFVPRKTSSGTGAEDPRPSRKRRSLGDVGNVTVAVPTVAAFPNTSSTSVPTSPEEHRPFEKVVNKESLVISGLRHFTGYRIELQACNQDTPEERCSVAAYVSARTMPEAKADDIVGPVTHEIFENNVVHLMWQEPKEPNGLIVLYEVSYRRYGDEELHLCVSRKHFALERGCRLRGLSPGNYSVRIRATSLAGNGSWTEPTYFYVTDYLDVPSNIAKIIIGPLIFVFLFSVVIGSIYLFLRKRQPDGPLGPLYASSNPEYLSASDVFPCSVYVPDEWEVSR\" \n",
    "]  "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d904fdf8-dfcc-4d9b-9c9c-acc157e86b32",
   "metadata": {},
   "source": [
    "## 5. Define the MLM Loss Function\n",
    "\n",
    "Our goal is to find which protein pairs produce the lowest MLM loss. We'll define a function that computes the MLM loss for a batch of protein pairs:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "8611bce1-fc17-4d90-a7fc-fea824cc5555",
   "metadata": {},
   "outputs": [],
   "source": [
    "BATCH_SIZE = 2\n",
    "NUM_MASKS = 10\n",
    "P_MASK = 0.15\n",
    "\n",
    "# Function to compute the MLM loss for a batch of concatenated protein pairs\n",
    "def compute_mlm_loss_batch(pairs):\n",
    "    # Tokenize once; only the random masking changes between iterations\n",
    "    inputs = tokenizer(pairs, return_tensors=\"pt\", truncation=True, padding=True, max_length=1022)\n",
    "    \n",
    "    # Move input tensors to GPU if available\n",
    "    inputs = {k: v.to(device) for k, v in inputs.items()}\n",
    "    \n",
    "    # Only real residues may be masked: exclude padding and special tokens (<cls>, <eos>, ...)\n",
    "    special_ids = torch.tensor(tokenizer.all_special_ids, device=device)\n",
    "    maskable = inputs[\"attention_mask\"].bool() & ~torch.isin(inputs[\"input_ids\"], special_ids)\n",
    "    \n",
    "    avg_losses = []\n",
    "    with torch.no_grad():  # inference only, so no gradients are needed\n",
    "        for _ in range(NUM_MASKS):\n",
    "            input_ids = inputs[\"input_ids\"].clone()\n",
    "            # Labels are -100 (ignored by the loss) everywhere except the masked positions\n",
    "            labels = torch.full_like(input_ids, -100)\n",
    "            \n",
    "            # Randomly mask P_MASK (15%) of the residues in each sequence of the batch\n",
    "            for idx in range(input_ids.shape[0]):\n",
    "                candidates = maskable[idx].nonzero(as_tuple=True)[0].cpu().numpy()\n",
    "                n_mask = max(1, int(P_MASK * len(candidates)))\n",
    "                mask_indices = torch.as_tensor(np.random.choice(candidates, size=n_mask, replace=False), device=device)\n",
    "                labels[idx, mask_indices] = input_ids[idx, mask_indices]\n",
    "                input_ids[idx, mask_indices] = tokenizer.mask_token_id\n",
    "            \n",
    "            # Compute the MLM loss (averaged over all masked positions in the batch)\n",
    "            outputs = model(input_ids=input_ids, attention_mask=inputs[\"attention_mask\"], labels=labels)\n",
    "            avg_losses.append(outputs.loss.item())\n",
    "    \n",
    "    # Return the loss averaged over the NUM_MASKS masking iterations\n",
    "    return sum(avg_losses) / NUM_MASKS"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8d12ad8c-5404-4bec-8c08-c25522fabc60",
   "metadata": {},
   "source": [
    "Here, it is important to note that using a larger model will likely improve results, but it will also require more compute. If you want to use a larger model, you might consider one of the other [ESM-2 models](https://huggingface.co/facebook/esm2_t36_3B_UR50D), for example, esm2_t36_3B_UR50D, and you can get a larger GPU from Brev. Keep in mind, however, that a larger ESM-2 model does not have a longer *context window* (the maximum sequence length the model can process at once): all ESM-2 checkpoints were trained on sequences of up to about 1,024 tokens. With `max_length=1022` in the code above, many of the concatenated pairs of the long proteins chosen here are truncated, so the end of the second protein never contributes to the score. You can experiment with `max_length`, but using shorter proteins will almost certainly yield better results. Either way, this should get you started."
   ]
  },
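  {
   "cell_type": "markdown",
   "id": "truncation-check-sketch",
   "metadata": {},
   "source": [
    "To see how much truncation affects a given pair, you can count how many residues of each protein survive the `max_length` cut. This is a minimal sketch that assumes the tokenizer adds exactly one start and one end token and produces one token per residue, as the ESM-2 tokenizer does for plain amino-acid strings; the helper `residues_kept` is hypothetical, not part of the notebook or of Transformers.\n",
    "\n",
    "```python\n",
    "def residues_kept(len_a, len_b, max_length=1022, num_special=2):\n",
    "    # Residues that fit once the special start/end tokens are accounted for\n",
    "    budget = max_length - num_special\n",
    "    kept_a = min(len_a, budget)\n",
    "    # Truncation removes tokens from the end, so protein B is cut first\n",
    "    kept_b = min(len_b, budget - kept_a)\n",
    "    return kept_a, kept_b\n",
    "\n",
    "# Example: a 400-residue protein followed by a 900-residue protein\n",
    "print(residues_kept(400, 900))  # (400, 620)\n",
    "```\n",
    "\n",
    "Calling it with the lengths of two sequences from `list_of_protein_sequences` (for example, `residues_kept(len(list_of_protein_sequences[0]), len(list_of_protein_sequences[3]))`) shows how much of the second protein actually contributes to that pair's score."
   ]
  },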
  {
   "cell_type": "markdown",
   "id": "ceb7275b-70de-4afb-99f7-4c9612e6f3d8",
   "metadata": {
    "id": "VCJnpZoayRgq"
   },
   "source": [
    "## 6. Construct the Loss Matrix\n",
    "\n",
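    "For each pair of proteins (i, j) with i < j, we concatenate the two sequences, compute the MLM loss, and store it in both `loss_matrix[i, j]` and `loss_matrix[j, i]`, because the matrix is treated as symmetric. The diagonal is set to infinity so that no protein can be paired with itself."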
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "798345ad-99a1-43d9-aad6-78fa37bb7311",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compute loss matrix\n",
    "num_proteins = len(list_of_protein_sequences)\n",
    "loss_matrix = np.zeros((num_proteins, num_proteins))\n",
    "\n",
    "for i in range(num_proteins):\n",
    "    for j in range(i + 1, num_proteins):  # j > i avoids self-pairing and duplicate work\n",
    "        # compute_mlm_loss_batch returns ONE loss averaged over everything it is given,\n",
    "        # so score each pair on its own; batching several pairs would give them all the same value\n",
    "        pair_loss = compute_mlm_loss_batch([list_of_protein_sequences[i] + list_of_protein_sequences[j]])\n",
    "        loss_matrix[i, j] = pair_loss\n",
    "        loss_matrix[j, i] = pair_loss  # the matrix is symmetric\n",
    "\n",
    "# Set the diagonal to infinity so that no protein is paired with itself\n",
    "np.fill_diagonal(loss_matrix, np.inf)"
   ]
  },
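  {
   "cell_type": "markdown",
   "id": "sketch-per-sequence-loss-md",
   "metadata": {},
   "source": [
    "`compute_mlm_loss_batch` returns a single loss averaged over everything in the batch, so pairs scored in the same batch cannot be told apart. If you want to score several pairs in one forward pass, one option is to compute a separate loss for each sequence from the model's logits. The sketch below assumes logits of shape `(batch, length, vocab)` (for example `outputs.logits` from the masked-LM model) and labels with `-100` at unmasked positions. The helper `per_sequence_mlm_loss` is hypothetical, and the toy check uses random values instead of the real model. Each entry of its result could then be written into `loss_matrix`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "sketch-per-sequence-loss-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn.functional as F\n",
    "\n",
    "def per_sequence_mlm_loss(logits, labels, ignore_index=-100):\n",
    "    # logits: (batch, length, vocab); labels: (batch, length), ignore_index at unscored positions.\n",
    "    # Returns one mean cross-entropy per sequence, shape (batch,).\n",
    "    batch, length, vocab = logits.shape\n",
    "    # Token-level losses; positions labelled ignore_index get a loss of 0\n",
    "    token_losses = F.cross_entropy(\n",
    "        logits.reshape(-1, vocab), labels.reshape(-1),\n",
    "        ignore_index=ignore_index, reduction='none',\n",
    "    ).view(batch, length)\n",
    "    scored = (labels != ignore_index).float()\n",
    "    # Average over scored positions only; clamp avoids 0/0 if a sequence has nothing masked\n",
    "    return (token_losses * scored).sum(dim=1) / scored.sum(dim=1).clamp(min=1.0)\n",
    "\n",
    "# Toy check (illustrative values): 2 sequences, 5 positions, 10 possible tokens\n",
    "gen = torch.Generator().manual_seed(0)\n",
    "batch_logits = torch.randn(2, 5, 10, generator=gen)\n",
    "batch_labels = torch.randint(0, 10, (2, 5), generator=gen)\n",
    "batch_labels[0, [0, 2, 4]] = -100  # sequence 0: only positions 1 and 3 are scored\n",
    "batch_labels[1, [1]] = -100        # sequence 1: positions 0, 2, 3 and 4 are scored\n",
    "\n",
    "per_seq = per_sequence_mlm_loss(batch_logits, batch_labels)\n",
    "print(per_seq)"
   ]
  },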
  {
   "cell_type": "markdown",
   "id": "759945e9-3171-4d4c-ba28-afaee4bac7b6",
   "metadata": {},
   "source": [
    "## 7. Find Optimal Pairs\n",
    "Let's find the optimal assignment that minimizes the total MLM loss."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "aff5e30f-1ee9-4945-ad16-2847cc49996c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[(0, 1), (1, 0), (2, 5), (3, 4), (4, 3), (5, 2)]\n"
     ]
    }
   ],
   "source": [
    "# Use the linear assignment problem to find the optimal pairing based on MLM loss\n",
    "rows, cols = linear_sum_assignment(loss_matrix)\n",
    "optimal_pairs = list(zip(rows, cols))\n",
    "\n",
    "print(optimal_pairs)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "690021c2-4932-406a-8df5-dcce8e9088da",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "Because residues are masked at random, your output may differ from run to run. When we ran the cell above, it printed\n",
    "`[(0, 1), (1, 0), (2, 5), (3, 4), (4, 3), (5, 2)]`.\n",
    "\n",
    "Recall that items are 0-indexed. So the first output, (0, 1), means that the 1st element of our `list_of_protein_sequences` is predicted to pair with the 2nd element of our `list_of_protein_sequences`, i.e., \"MEESQSELNIDPPLSQETFSELWNLLPENNVLSSELCPAVDELLLPESVVNWLDEDSDDAPRMPATSAPTAPGPAPSWPLSSSVPSPKTYPGTYGFRLGFLHSGTAKSVTWTYSPLLNKLFCQLAKTCPVQLWVSSPPPPNTCVRAMAIYKKSEFVTEVVRRCPHHERCSDSSDGLAPPQHLIRVEGNLRAKYLDDRNTFRHSVVVPYEPPEVGSDYTTIHYNYMCNSSCMGGMNRRPILTIITLEDSSGNVLGRNSFEVRVCACPGRDRRTEEENFHKKGEPCPEPPPGSTKRALPPSTSSSPPQKKKPLDGEYFTLQIRGRERYEMFRNLNEALELKDAQSGKEPGGSRAHSSHLKAKKGQSTSRHKKLMFKREGLDSD\" and \"MCNTNMSVPTDGAVTTSQIPASEQETLVRPKPLLLKLLKSVGAQKDTYTMKEVLFYLGQYIMTKRLYDEKQQHIVYCSNDLLGDLFGVPSFSVKEHRKIYTMIYRNLVVVNQQESSDSGTSVSENRCHLEGGSDQKDLVQELQEEKPSSSHLVSRPSTSSRRRAISETEENSDELSGERQRKRHKSDSISLSFDESLALCVIREICCERSSSSESTGTPSNPDLDAGVSEHSGDWLDQDSVSDQFSVEFEVESLDSEDYSLSEEGQELSDEDDEVYQVTVYQAGESDTDSFEEDPEISLADYWKCTSCNEMNPPLPSHCNRCWALRENWLPEDKGKDKGEISEKAKLENSTQAEEGFDVPDCKKTIVNDSRESCVEENDDKITQASQSQESEDYSQPSTSSSIIYSSQEDVKEFEREETQDKEESVESSLPLNAIEPCVICQGRPKNGCIVHGKTGHLMACFTCAKKLKKRNKPCPVCRQPIQMIVLTYFP\". Every pairing in this output is mutual, so each pair appears twice, once in each order, e.g. (0, 1) and (1, 0).\n",
    "\n",
    "### Sweet! These are the predicted optimal pairs!\n",
    "\n",
    "Next, let's look at another way to use these scores: building a predicted protein-protein interaction network and visualizing it in 3D."
   ]
  },
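  {
   "cell_type": "markdown",
   "id": "sketch-assignment-pairs-md",
   "metadata": {},
   "source": [
    "`linear_sum_assignment` returns a permutation. Every protein i is assigned a partner `cols[i]`, but nothing forces that assignment to be mutual. On a symmetric matrix the optimum can also contain longer cycles such as 0 → 1 → 2 → 0, which do not correspond to binding pairs. With an odd number of proteins, some cycle must have odd length. The sketch below uses small made-up loss matrices to show three things: how to extract the unique pairs from an assignment, how to check whether it is mutual, and an alternative that always returns disjoint pairs. That alternative is a minimum-loss matching computed with NetworkX's `max_weight_matching`, which is a different technique from the one used above. The toy values and the helper names (`is_mutual`, `unique_pairs`, `min_loss_pairing`) are illustrative assumptions. On the notebook's own matrix, you could compare `unique_pairs(rows, cols)` with `min_loss_pairing(loss_matrix)`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "sketch-assignment-pairs-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import networkx as nx\n",
    "import numpy as np\n",
    "from scipy.optimize import linear_sum_assignment\n",
    "\n",
    "# Made-up symmetric loss matrix for 4 proteins: pairs (0, 1) and (2, 3) have the lowest losses\n",
    "toy_loss = np.array([\n",
    "    [np.inf, 1.0, 5.0, 5.0],\n",
    "    [1.0, np.inf, 5.0, 5.0],\n",
    "    [5.0, 5.0, np.inf, 1.0],\n",
    "    [5.0, 5.0, 1.0, np.inf],\n",
    "])\n",
    "toy_rows, toy_cols = linear_sum_assignment(toy_loss)\n",
    "\n",
    "def is_mutual(cols):\n",
    "    # For a square matrix rows == 0..n-1, so protein i is assigned partner cols[i]\n",
    "    return all(cols[cols[i]] == i for i in range(len(cols)))\n",
    "\n",
    "def unique_pairs(rows, cols):\n",
    "    # Report each unordered pair once, e.g. (0, 1) and (1, 0) -> (0, 1)\n",
    "    return sorted({tuple(sorted((int(r), int(c)))) for r, c in zip(rows, cols)})\n",
    "\n",
    "print('mutual:', is_mutual(toy_cols))\n",
    "print('pairs from linear_sum_assignment:', unique_pairs(toy_rows, toy_cols))\n",
    "\n",
    "# With 3 proteins no mutual pairing exists, so the assignment is a 3-cycle\n",
    "toy_loss_3 = np.array([[np.inf, 1.0, 2.0], [1.0, np.inf, 3.0], [2.0, 3.0, np.inf]])\n",
    "print('mutual (3 proteins):', is_mutual(linear_sum_assignment(toy_loss_3)[1]))\n",
    "\n",
    "def min_loss_pairing(loss):\n",
    "    # Disjoint pairs with minimum total loss via maximum-weight matching on a complete graph.\n",
    "    # Assumes all off-diagonal losses are finite.\n",
    "    n = loss.shape[0]\n",
    "    offset = loss[np.isfinite(loss)].max() + 1.0  # positive weights: lower loss -> higher weight\n",
    "    graph = nx.Graph()\n",
    "    for i in range(n):\n",
    "        for j in range(i + 1, n):\n",
    "            graph.add_edge(i, j, weight=offset - loss[i, j])\n",
    "    # maxcardinality=True: among matchings with the most pairs, maximize total weight,\n",
    "    # which is the same as minimizing the total loss of the chosen pairs\n",
    "    matching = nx.max_weight_matching(graph, maxcardinality=True, weight='weight')\n",
    "    return sorted(tuple(sorted(edge)) for edge in matching)\n",
    "\n",
    "print('pairs from matching:', min_loss_pairing(toy_loss))"
   ]
  },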
  {
   "cell_type": "markdown",
   "id": "e73e9d41-e70f-4679-846c-0fde3e2b2291",
   "metadata": {},
   "source": [
    "## 8. Build a PPI Network\n",
    "\n",
    "We can also use this method to build a protein-protein interaction (PPI) network. We compute the MLM loss for every pair of proteins and add an edge between two proteins whenever their loss falls below a chosen threshold. This gives a fast way to approximate protein interactomes and to find candidate interactions. We can do this as follows."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "7b424ac8-4b6a-4385-80ce-64c6b043693b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Maximum loss (maximum threshold for slider): 8.742724609375\n"
     ]
    }
   ],
   "source": [
    "import networkx as nx\n",
    "import numpy as np\n",
    "import torch\n",
    "from transformers import AutoTokenizer, AutoModelForMaskedLM\n",
    "import plotly.graph_objects as go\n",
    "from ipywidgets import interact\n",
    "from ipywidgets import widgets\n",
    "\n",
    "# Check if CUDA is available and set the default device accordingly\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "# Load the pretrained (or fine-tuned) ESM-2 model and tokenizer\n",
    "model_name = \"facebook/esm2_t6_8M_UR50D\"  # You can change this to your fine-tuned model\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
    "model = AutoModelForMaskedLM.from_pretrained(model_name)\n",
    "\n",
    "# Send the model to the device (GPU or CPU)\n",
    "model.to(device)\n",
    "\n",
    "# Ensure the model is in evaluation mode\n",
    "model.eval()\n",
    "\n",
    "# Define Protein Sequences (Replace with your list)\n",
    "all_proteins = [\n",
    "    \"MFLSILVALCLWLHLALGVRGAPCEAVRIPMCRHMPWNITRMPNHLHHSTQENAILAIEQYEELVDVNCSAVLRFFLCAMYAPICTLEFLHDPIKPCKSVCQRARDDCEPLMKMYNHSWPESLACDELPVYDRGVCISPEAIVTDLPEDVKWIDITPDMMVQERPLDVDCKRLSPDRCKCKKVKPTLATYLSKNYSYVIHAKIKAVQRSGCNEVTTVVDVKEIFKSSSPIPRTQVPLITNSSCQCPHILPHQDVLIMCYEWRSRMMLLENCLVEKWRDQLSKRSIQWEERLQEQRRTVQDKKKTAGRTSRSNPPKPKGKPPAPKPASPKKNIKTRSAQKRTNPKRV\",\n",
    "    \"MDAVEPGGRGWASMLACRLWKAISRALFAEFLATGLYVFFGVGSVMRWPTALPSVLQIAITFNLVTAMAVQVTWKASGAHANPAVTLAFLVGSHISLPRAVAYVAAQLVGATVGAALLYGVMPGDIRETLGINVVRNSVSTGQAVAVELLLTLQLVLCVFASTDSRQTSGSPATMIGISVALGHLIGIHFTGCSMNPARSFGPAIIIGKFTVHWVFWVGPLMGALLASLIYNFVLFPDTKTLAQRLAILTGTVEVGTGAGAGAEPLKKESQPGSGAVEMESV\", \n",
    "    \"MKFLLDILLLLPLLIVCSLESFVKLFIPKRRKSVTGEIVLITGAGHGIGRLTAYEFAKLKSKLVLWDINKHGLEETAAKCKGLGAKVHTFVVDCSNREDIYSSAKKVKAEIGDVSILVNNAGVVYTSDLFATQDPQIEKTFEVNVLAHFWTTKAFLPAMTKNNHGHIVTVASAAGHVSVPFLLAYCSSKFAAVGFHKTLTDELAALQITGVKTTCLCPNFVNTGFIKNPSTSLGPTLEPEEVVNRLMHGILTEQKMIFIPSSIAFLTTLERILPERFLAVLKRKISVKFDAVIGYKMKAQ\", \n",
    "    \n",
    "    \"MAAAVPRRPTQQGTVTFEDVAVNFSQEEWCLLSEAQRCLYRDVMLENLALISSLGCWCGSKDEEAPCKQRISVQRESQSRTPRAGVSPKKAHPCEMCGLILEDVFHFADHQETHHKQKLNRSGACGKNLDDTAYLHQHQKQHIGEKFYRKSVREASFVKKRKLRVSQEPFVFREFGKDVLPSSGLCQEEAAVEKTDSETMHGPPFQEGKTNYSCGKRTKAFSTKHSVIPHQKLFTRDGCYVCSDCGKSFSRYVSFSNHQRDHTAKGPYDCGECGKSYSRKSSLIQHQRVHTGQTAYPCEECGKSFSQKGSLISHQLVHTGEGPYECRECGKSFGQKGNLIQHQQGHTGERAYHCGECGKSFRQKFCFINHQRVHTGERPYKCGECGKSFGQKGNLVHHQRGHTGERPYECKECGKSFRYRSHLTEHQRLHTGERPYNCRECGKLFNRKYHLLVHERVHTGERPYACEVCGKLFGNKHSVTIHQRIHTGERPYECSECGKSFLSSSALHVHKRVHSGQKPYKCSECGKSFSECSSLIKHRRIHTGERPYECTKCGKTFQRSSTLLHHQSSHRRKAL\", \n",
    "    \"MGQPWAAGSTDGAPAQLPLVLTALWAAAVGLELAYVLVLGPGPPPLGPLARALQLALAAFQLLNLLGNVGLFLRSDPSIRGVMLAGRGLGQGWAYCYQCQSQVPPRSGHCSACRVCILRRDHHCRLLGRCVGFGNYRPFLCLLLHAAGVLLHVSVLLGPALSALLRAHTPLHMAALLLLPWLMLLTGRVSLAQFALAFVTDTCVAGALLCGAGLLFHGMLLLRGQTTWEWARGQHSYDLGPCHNLQAALGPRWALVWLWPFLASPLPGDGITFQTTADVGHTAS\", \n",
    "    \"MGLRIHFVVDPHGWCCMGLIVFVWLYNIVLIPKIVLFPHYEEGHIPGILIIIFYGISIFCLVALVRASITDPGRLPENPKIPHGEREFWELCNKCNLMRPKRSHHCSRCGHCVRRMDHHCPWINNCVGEDNHWLFLQLCFYTELLTCYALMFSFCHYYYFLPLKKRNLDLFVFRHELAIMRLAAFMGITMLVGITGLFYTQLIGIITDTTSIEKMSNCCEDISRPRKPWQQTFSEVFGTRWKILWFIPFRQRQPLRVPYHFANHV\", \n",
    "    \n",
    "    \"MLLLGAVLLLLALPGHDQETTTQGPGVLLPLPKGACTGWMAGIPGHPGHNGAPGRDGRDGTPGEKGEKGDPGLIGPKGDIGETGVPGAEGPRGFPGIQGRKGEPGEGAYVYRSAFSVGLETYVTIPNMPIRFTKIFYNQQNHYDGSTGKFHCNIPGLYYFAYHITVYMKDVKVSLFKKDKAMLFTYDQYQENNVDQASGSVLLHLEVGDQVWLQVYGEGERNGLYADNDNDSTFTGFLLYHDTN\", \n",
    "    \"MGLLAFLKTQFVLHLLVGFVFVVSGLVINFVQLCTLALWPVSKQLYRRLNCRLAYSLWSQLVMLLEWWSCTECTLFTDQATVERFGKEHAVIILNHNFEIDFLCGWTMCERFGVLGSSKVLAKKELLYVPLIGWTWYFLEIVFCKRKWEEDRDTVVEGLRRLSDYPEYMWFLLYCEGTRFTETKHRVSMEVAAAKGLPVLKYHLLPRTKGFTTAVKCLRGTVAAVYDVTLNFRGNKNPSLLGILYGKKYEADMCVRRFPLEDIPLDEKEAAQWLHKLYQEKDALQEIYNQKGMFPGEQFKPARRPWTLLNFLSWATILLSPLFSFVLGVFASGSPLLILTFLGFVGAASFGVRRLIGVTEIEKGSSYGNQEFKKKE\", \n",
    "    \"MDLAGLLKSQFLCHLVFCYVFIASGLIINTIQLFTLLLWPINKQLFRKINCRLSYCISSQLVMLLEWWSGTECTIFTDPRAYLKYGKENAIVVLNHKFEIDFLCGWSLSERFGLLGGSKVLAKKELAYVPIIGWMWYFTEMVFCSRKWEQDRKTVATSLQHLRDYPEKYFFLIHCEGTRFTEKKHEISMQVARAKGLPRLKHHLLPRTKGFAITVRSLRNVVSAVYDCTLNFRNNENPTLLGVLNGKKYHADLYVRRIPLEDIPEDDDECSAWLHKLYQEKDAFQEEYYRTGTFPETPMVPPRRPWTLVNWLFWASLVLYPFFQFLVSMIRSGSSLTLASFILVFFVASVGVRWMIGVTEIDKGSAYGNSDSKQKLND\", \n",
    "    \n",
    "    \"MALLLCFVLLCGVVDFARSLSITTPEEMIEKAKGETAYLPCKFTLSPEDQGPLDIEWLISPADNQKVDQVIILYSGDKIYDDYYPDLKGRVHFTSNDLKSGDASINVTNLQLSDIGTYQCKVKKAPGVANKKIHLVVLVKPSGARCYVDGSEEIGSDFKIKCEPKEGSLPLQYEWQKLSDSQKMPTSWLAEMTSSVISVKNASSEYSGTYSCTVRNRVGSDQCLLRLNVVPPSNKAGLIAGAIIGTLLALALIGLIIFCCRKKRREEKYEKEVHHDIREDVPPPKSRTSTARSYIGSNHSSLGSMSPSNMEGYSKTQYNQVPSEDFERTPQSPTLPPAKVAAPNLSRMGAIPVMIPAQSKDGSIV\", \n",
    "    \"MSYVFVNDSSQTNVPLLQACIDGDFNYSKRLLESGFDPNIRDSRGRTGLHLAAARGNVDICQLLHKFGADLLATDYQGNTALHLCGHVDTIQFLVSNGLKIDICNHQGATPLVLAKRRGVNKDVIRLLESLEEQEVKGFNRGTHSKLETMQTAESESAMESHSLLNPNLQQGEGVLSSFRTTWQEFVEDLGFWRVLLLIFVIALLSLGIAYYVSGVLPFVENQPELVH\", \n",
    "    \"MRVAGAAKLVVAVAVFLLTFYVISQVFEIKMDASLGNLFARSALDTAARSTKPPRYKCGISKACPEKHFAFKMASGAANVVGPKICLEDNVLMSGVKNNVGRGINVALANGKTGEVLDTKYFDMWGGDVAPFIEFLKAIQDGTIVLMGTYDDGATKLNDEARRLIADLGSTSITNLGFRDNWVFCGGKGIKTKSPFEQHIKNNKDTNKYEGWPEVVEMEGCIPQKQD\", \n",
    "    \n",
    "    \"MAPAAATGGSTLPSGFSVFTTLPDLLFIFEFIFGGLVWILVASSLVPWPLVQGWVMFVSVFCFVATTTLIILYIIGAHGGETSWVTLDAAYHCTAALFYLSASVLEALATITMQDGFTYRHYHENIAAVVFSYIATLLYVVHAVFSLIRWKSS\", \n",
    "    \"MRLQGAIFVLLPHLGPILVWLFTRDHMSGWCEGPRMLSWCPFYKVLLLVQTAIYSVVGYASYLVWKDLGGGLGWPLALPLGLYAVQLTISWTVLVLFFTVHNPGLALLHLLLLYGLVVSTALIWHPINKLAALLLLPYLAWLTVTSALTYHLWRDSLCPVHQPQPTEKSD\", \n",
    "    \"MEESVVRPSVFVVDGQTDIPFTRLGRSHRRQSCSVARVGLGLLLLLMGAGLAVQGWFLLQLHWRLGEMVTRLPDGPAGSWEQLIQERRSHEVNPAAHLTGANSSLTGSGGPLLWETQLGLAFLRGLSYHDGALVVTKAGYYYIYSKVQLGGVGCPLGLASTITHGLYKRTPRYPEELELLVSQQSPCGRATSSSRVWWDSSFLGGVVHLEAGEKVVVRVLDERLVRLRDGTRSYFGAFMV\"\n",
    "]\n",
    "\n",
    "def compute_average_mlm_loss(protein1, protein2, iterations=10, mask_prob=0.55):\n",
    "    # Average MLM loss of protein1 + connector + protein2 over several random maskings\n",
    "    connector = \"G\" * 25  # Connector sequence of G's\n",
    "    concatenated_sequence = protein1 + connector + protein2\n",
    "    inputs = tokenizer(concatenated_sequence, return_tensors=\"pt\", truncation=True, max_length=1024)\n",
    "    inputs = {k: v.to(device) for k, v in inputs.items()}  # Send inputs to the device\n",
    "    original_ids = inputs[\"input_ids\"]\n",
    "\n",
    "    # Never mask the special tokens (<cls> first, <eos> last)\n",
    "    maskable = torch.ones_like(original_ids, dtype=torch.bool)\n",
    "    maskable[:, 0] = False\n",
    "    maskable[:, -1] = False\n",
    "\n",
    "    # Locate the connector 'G's (shifted by 1 for the leading <cls> token) and never mask them\n",
    "    connector_length = len(tokenizer.encode(connector, add_special_tokens=False))\n",
    "    start_connector = 1 + len(tokenizer.encode(protein1, add_special_tokens=False))\n",
    "    end_connector = start_connector + connector_length\n",
    "    maskable[0, start_connector:end_connector] = False\n",
    "\n",
    "    total_loss = 0.0\n",
    "    for _ in range(iterations):\n",
    "        mask_indices = (torch.rand(original_ids.shape, device=device) < mask_prob) & maskable\n",
    "\n",
    "        # Apply the mask to a copy of the input IDs\n",
    "        masked_ids = original_ids.clone()\n",
    "        masked_ids[mask_indices] = tokenizer.mask_token_id\n",
    "\n",
    "        # Labels: the original residues at masked positions, -100 (ignored by the loss) elsewhere\n",
    "        labels = original_ids.clone()\n",
    "        labels[~mask_indices] = -100\n",
    "\n",
    "        with torch.no_grad():\n",
    "            outputs = model(input_ids=masked_ids, attention_mask=inputs.get(\"attention_mask\"), labels=labels)\n",
    "        total_loss += outputs.loss.item()\n",
    "\n",
    "    return total_loss / iterations\n",
    "\n",
    "# Compute all average losses to determine the maximum threshold for the slider\n",
    "all_losses = []\n",
    "for i, protein1 in enumerate(all_proteins):\n",
    "    for j, protein2 in enumerate(all_proteins[i+1:], start=i+1):\n",
    "        avg_loss = compute_average_mlm_loss(protein1, protein2)\n",
    "        all_losses.append(avg_loss)\n",
    "\n",
    "# Set the maximum threshold to the maximum loss computed\n",
    "max_threshold = max(all_losses)\n",
    "print(f\"Maximum loss (maximum threshold for slider): {max_threshold}\")"
   ]
  },
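  {
   "cell_type": "markdown",
   "id": "sketch-maskable-positions-md",
   "metadata": {},
   "source": [
    "The masking in `compute_average_mlm_loss` depends on the layout of the tokenized input. Assume, as with the ESM-2 tokenizer, one token per residue, a leading `<cls>` token and a trailing `<eos>` token. Then the connector occupies positions `1 + len(protein1)` up to, but not including, `1 + len(protein1) + len(connector)`. Forgetting the `<cls>` offset would shift the protected window one position to the left. That would protect the last residue of protein1 and leave the last connector token maskable. The pure-PyTorch sketch below builds the mask of positions that may be masked for a toy pair without loading the tokenizer. The helper `maskable_positions` is hypothetical."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "sketch-maskable-positions-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "def maskable_positions(len_protein1, len_connector, len_protein2):\n",
    "    # Boolean mask over <cls> + protein1 + connector + protein2 + <eos>.\n",
    "    # True marks residues that may be masked; special tokens and the connector are excluded.\n",
    "    # Assumes one token per residue, as for amino-acid sequences.\n",
    "    total = 1 + len_protein1 + len_connector + len_protein2 + 1\n",
    "    maskable = torch.ones(total, dtype=torch.bool)\n",
    "    maskable[0] = False   # <cls>\n",
    "    maskable[-1] = False  # <eos>\n",
    "    start = 1 + len_protein1  # first connector token comes after <cls> and protein1\n",
    "    maskable[start:start + len_connector] = False\n",
    "    return maskable\n",
    "\n",
    "# Toy pair: 3-residue protein, 2-residue connector, 4-residue protein\n",
    "toy_maskable = maskable_positions(3, 2, 4)\n",
    "print(toy_maskable.int().tolist())  # [0, 1, 1, 1, 0, 0, 1, 1, 1, 1, 0]"
   ]
  },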
  {
   "cell_type": "markdown",
   "id": "920f8349-67dc-45b2-bf0f-5d2f3f478843",
   "metadata": {},
   "source": [
    "Now, let's plot an interactive 3D graph of the predicted PPI network. You can adjust the MLM loss threshold with the slider to make the criterion for an interaction more or less strict."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "acdac422-bf81-4110-99b8-9cb37a634ed0",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_graph(threshold):\n",
    "    G = nx.Graph()\n",
    "\n",
    "    # Add all protein nodes to the graph\n",
    "    for i, protein in enumerate(all_proteins):\n",
    "        G.add_node(f\"protein {i+1}\")\n",
    "\n",
    "    # Loop through all pairs of proteins and look up their precomputed average MLM loss\n",
    "    loss_idx = 0  # Index to keep track of the position in the all_losses list\n",
    "    for i, protein1 in enumerate(all_proteins):\n",
    "        for j, protein2 in enumerate(all_proteins[i+1:], start=i+1):\n",
    "            avg_loss = all_losses[loss_idx]\n",
    "            loss_idx += 1\n",
    "            \n",
    "            # Add an edge if the loss is below the threshold\n",
    "            if avg_loss < threshold:\n",
    "                G.add_edge(f\"protein {i+1}\", f\"protein {j+1}\", weight=round(avg_loss, 3))\n",
    "\n",
    "    # 3D Network Plot\n",
    "    # Adjust the k parameter to bring nodes closer. This might require some experimentation to find the right value.\n",
    "    k_value = 2  # Lower value will bring nodes closer together\n",
    "    pos = nx.spring_layout(G, dim=3, seed=42, k=k_value)\n",
    "\n",
    "    edge_x = []\n",
    "    edge_y = []\n",
    "    edge_z = []\n",
    "    for edge in G.edges():\n",
    "        x0, y0, z0 = pos[edge[0]]\n",
    "        x1, y1, z1 = pos[edge[1]]\n",
    "        edge_x.extend([x0, x1, None])\n",
    "        edge_y.extend([y0, y1, None])\n",
    "        edge_z.extend([z0, z1, None])\n",
    "    \n",
    "    edge_trace = go.Scatter3d(x=edge_x, y=edge_y, z=edge_z, mode='lines', line=dict(width=0.5, color='grey'))\n",
    "    \n",
    "    node_x = []\n",
    "    node_y = []\n",
    "    node_z = []\n",
    "    node_text = []\n",
    "    for node in G.nodes():\n",
    "        x, y, z = pos[node]\n",
    "        node_x.append(x)\n",
    "        node_y.append(y)\n",
    "        node_z.append(z)\n",
    "        node_text.append(node)\n",
    "    \n",
    "    node_trace = go.Scatter3d(x=node_x, y=node_y, z=node_z, mode='markers', marker=dict(size=5), hoverinfo='text', hovertext=node_text)\n",
    "    \n",
    "    layout = go.Layout(title='Protein Interaction Graph', title_x=0.5, scene=dict(xaxis=dict(showbackground=False), yaxis=dict(showbackground=False), zaxis=dict(showbackground=False)))\n",
    "\n",
    "    fig = go.Figure(data=[edge_trace, node_trace], layout=layout)\n",
    "    fig.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "863cf35f-93c4-453f-b1a2-050a61c0a1ba",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "fd5f7b4ce95d424aa8e3d9cd3e57a6df",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "interactive(children=(FloatSlider(value=8.25, description='threshold', max=8.742724609375, step=0.05), Output(…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "<function __main__.plot_graph(threshold)>"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Create an interactive slider for the threshold value, starting at the median pairwise loss\n",
    "interact(plot_graph, threshold=widgets.FloatSlider(min=0.0, max=max_threshold, step=0.05, value=float(np.median(all_losses))))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "96fd9ec5-b964-4aad-a464-497a4a24eafb",
   "metadata": {},
   "source": [
    "So, for example, try moving the slider a little below and above its starting value and see what kind of predicted interactome each MLM loss threshold produces. You should also adjust the fraction of masked tokens (`mask_prob`) to see how it affects the graph. In general, masking more residues raises the threshold needed for connections to appear in the predicted PPI graph. Note that this code may take a few moments to run, since it computes the loss for every pair of proteins in the list, but in general this is a fast, zero-shot method for predicting PPI networks. As a next step, you might consider fine-tuning the model to predict the masked residues of known interacting pairs. Another interesting and important question is how protein length affects this computation: is the method robust to large variations in length, or do the proteins need to be of similar lengths for it to work?"
   ]
  },
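  {
   "cell_type": "markdown",
   "id": "sketch-threshold-edges-md",
   "metadata": {},
   "source": [
    "You can check the threshold rule used in `plot_graph` without re-running the model. The sketch below applies the same rule, an edge whenever the loss is below the threshold, to a few made-up pairwise losses. It then counts the resulting edges for several thresholds, which shows how quickly the predicted network becomes denser as the threshold rises. The loss values and the helper name `edges_below` are illustrative assumptions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "sketch-threshold-edges-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import networkx as nx\n",
    "\n",
    "# Made-up average MLM losses for pairs of 4 proteins (illustrative values only)\n",
    "toy_pair_losses = {\n",
    "    (1, 2): 2.1,\n",
    "    (1, 3): 2.9,\n",
    "    (1, 4): 3.4,\n",
    "    (2, 3): 2.5,\n",
    "    (2, 4): 3.1,\n",
    "    (3, 4): 2.2,\n",
    "}\n",
    "\n",
    "def edges_below(pair_losses, threshold):\n",
    "    # One node per protein, plus an edge for every pair whose loss is below the threshold\n",
    "    graph = nx.Graph()\n",
    "    for (a, b) in pair_losses:\n",
    "        graph.add_nodes_from([f'protein {a}', f'protein {b}'])\n",
    "    for (a, b), loss in pair_losses.items():\n",
    "        if loss < threshold:\n",
    "            graph.add_edge(f'protein {a}', f'protein {b}', weight=loss)\n",
    "    return graph\n",
    "\n",
    "for threshold in (2.0, 2.4, 3.0, 3.5):\n",
    "    graph = edges_below(toy_pair_losses, threshold)\n",
    "    print(threshold, graph.number_of_edges(), sorted(graph.edges()))"
   ]
  },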
  {
   "cell_type": "markdown",
   "id": "470f673e-fbef-423e-91bb-6f64e1a593db",
   "metadata": {},
   "source": [
    "# Conclusion\n",
    "Now you should be able to use the ESM-2 model to predict potential protein-protein interactions by comparing the MLM loss of different protein pairings. This method provides a novel way of inferring interactions using deep learning techniques. Remember, this approach provides a heuristic and should be combined with experimental validation for conclusive results. As a next step, you might try implementing the ideas in [PepMLM: Target Sequence-Conditioned Generation of Peptide Binders via Masked Language Modeling](https://arxiv.org/abs/2310.03842), which finetunes ESM-2 for generating binding partners using masked language modeling, or you might try finetuning ESM-2 in a similar fashion on concatenated pairs of binding partners, with some percentage of the tokens in each binding partner masked, to see if performance is improved.\n",
    "\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
